{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from torch.utils.data import Dataset\n",
    "from tqdm.auto import tqdm\n",
    "import json \n",
    "from pathlib import Path\n",
    "from dataclasses import dataclass, field\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class ImageTextData:\n",
    "    image_path: Path\n",
    "    text: str\n",
    "\n",
    "class ImageTextDataset(Dataset):\n",
    "    def __init__(self, data_size:str, language:str) -> None:\n",
    "        super().__init__()\n",
    "        self.data_size = data_size\n",
    "        self.language = language\n",
    "        if data_size =='small':\n",
    "            self.image_dir = Path(\"data/jackyhate/text-to-image-2M/data_1024_10K\")\n",
    "            self.text_file = Path(\"data/jackyhate/text-to-image-2M/data_1024_10K.json\")\n",
    "\n",
    "        elif data_size == 'large':\n",
    "            self.image_dir = Path(\"data/jackyhate/unzip2mdata\")\n",
    "            self.text_file = Path(\"data/jackyhate/text-to-image-2M/data_1024_2m.json\")\n",
    "\n",
    "        self.text_data = self.load_text_data()\n",
    "        \n",
    "\n",
    "\n",
    "    def load_text_data(self):\n",
    "        with open(self.text_file, 'r') as fin:\n",
    "            all_data = fin.readlines()\n",
    "            all_data = [json.loads(data) for data in all_data]\n",
    "        return all_data\n",
    "    \n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.text_data)\n",
    "    \n",
    "    def __getitem__(self, idx) -> ImageTextData:\n",
    "        temp_data = self.text_data[idx]\n",
    "        image_path = self.image_dir / f\"{temp_data.get('file_path')}.jpg\"\n",
    "        \n",
    "        if self.language =='zh':\n",
    "            text = temp_data.get('zh_prompt', None)\n",
    "            if text is None:\n",
    "                text = temp_data.get('prompt', None)\n",
    "\n",
    "        else:\n",
    "            text = temp_data.get('prompt', None)\n",
    "\n",
    "        return ImageTextData(image_path, text)\n",
    "    \n",
    "\n",
    "\n",
    "test_small_dataset = ImageTextDataset('small', 'en')\n",
    "len(test_small_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Image.open(test_small_dataset[1].image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_large_dataset = ImageTextDataset('large', 'en')\n",
    "len(test_large_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Image.open(test_large_dataset[0].image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from embedding.data import ImageTextData, ImageTextDataset\n",
    "from PIL import Image\n",
    "\n",
    "itd = ImageTextDataset('large', 'zh')\n",
    "len(itd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = itd[0]\n",
    "a"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Image.open(a.image_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "hz_net_v2",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
